import os
import sys
import glob
import numpy as np

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

import sys
sys.path.append('../../')
import modules.tune.dsine.utils.utils as utils
import modules.tune.dsine.projects.dsine.config as config

# Build the DSINE test-time configuration and point it at the released checkpoint.
args = config.get_args(test=True)

# NOTE: relative path, resolved against the current working directory
# (not this file's location).
args.ckpt_path = './modules/tune/dsine/projects/dsine/checkpoints/exp001_cvpr2024/dsine.pt'

print("args.ckpt_path", args.ckpt_path)
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not os.path.exists(args.ckpt_path):
    raise FileNotFoundError(f"DSINE checkpoint not found: {args.ckpt_path}")

from modules.tune.dsine.models.dsine.v02 import DSINE_v02 as DSINE

device = "cuda:0"

# The model is built, loaded and switched to eval mode once, at import time;
# normal_estimation() uses this module-level instance.
model = DSINE(args).to(device)
model = utils.load_checkpoint(args.ckpt_path, model)
model.eval()




def intrins_from_fov(new_fov, H, W, dtype=torch.float32, device='cpu'):
    """Build a 3x3 pinhole intrinsics matrix from a single field of view.

    The FOV (in degrees) is measured across the longer image side, so both
    focal lengths equal ``(max(H, W) / 2) / tan(fov / 2)``. The principal
    point is placed at the image centre using a pixel-centre convention in
    which the top-left pixel is (0, 0), i.e. ``(W / 2 - 0.5, H / 2 - 0.5)``.

    Args:
        new_fov: field of view in degrees.
        H: image height in pixels.
        W: image width in pixels.
        dtype: dtype of the returned tensor.
        device: device of the returned tensor.

    Returns:
        A (3, 3) tensor ``[[f, 0, cx], [0, f, cy], [0, 0, 1]]``.
    """
    def as_t(value):
        return torch.tensor(value, dtype=dtype, device=device)

    height, width = as_t(H), as_t(W)
    half_angle = torch.deg2rad(as_t(new_fov)) / 2.0
    focal = (torch.max(height, width) / 2.0) / torch.tan(half_angle)

    K = torch.zeros((3, 3), dtype=dtype, device=device)
    K[0, 0] = focal
    K[1, 1] = focal
    K[0, 2] = width / 2.0 - 0.5
    K[1, 2] = height / 2.0 - 0.5
    K[2, 2] = 1.0
    return K
    
def calculate_intrins(fovx, fovy, H, W, dtype=torch.float32, device='cpu'):
    """Build a 3x3 pinhole intrinsics matrix from horizontal and vertical FOVs.

    Each field of view spans its own image axis, so the focal lengths are

        fx = (W / 2) / tan(fovx / 2)
        fy = (H / 2) / tan(fovy / 2)

    The principal point is placed at the image centre using a pixel-centre
    convention in which the top-left pixel is (0, 0): ``(W / 2 - 0.5, H / 2 - 0.5)``.

    Args:
        fovx: horizontal field of view in degrees.
        fovy: vertical field of view in degrees.
        H: image height in pixels.
        W: image width in pixels.
        dtype: dtype of the returned tensor.
        device: device of the returned tensor.

    Returns:
        A (3, 3) tensor ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``.
    """
    def as_t(value):
        return torch.tensor(value, dtype=dtype, device=device)

    fovx_rad = torch.deg2rad(as_t(fovx))
    fovy_rad = torch.deg2rad(as_t(fovy))
    height, width = as_t(H), as_t(W)

    # Each focal length uses the extent of the axis its FOV spans. Using
    # max(H, W) for both would give a wrong fy (or fx) for non-square images.
    fx = (width / 2.0) / torch.tan(fovx_rad / 2.0)
    fy = (height / 2.0) / torch.tan(fovy_rad / 2.0)

    intrins = torch.zeros((3, 3), dtype=dtype, device=device)
    intrins[0, 0] = fx
    intrins[1, 1] = fy
    intrins[0, 2] = width / 2.0 - 0.5
    intrins[1, 2] = height / 2.0 - 0.5
    intrins[2, 2] = 1.0
    return intrins


def normal_estimation(img: torch.Tensor, intrins: "torch.Tensor | None" = None):
    """Predict a surface-normal map for a single image with the module-level DSINE model.

    The image gets a batch axis and is zero-padded by the amounts returned by
    ``utils.get_padding``. It is then normalised with ImageNet mean/std and passed
    to ``model`` together with the intrinsics. The intrinsics' principal point is
    shifted to account for the padding. The prediction is cropped back to the
    original size.

    Args:
        img: a single, unbatched image tensor (a batch axis is added here).
            Presumably (3, H, W) RGB in [0, 1]; TODO confirm against callers.
        intrins: (3, 3) intrinsics for the *unpadded* image. The tensor is not
            modified. If None, a centred principal point and a 60-degree FOV
            (over the longer side) are assumed via ``intrins_from_fov``.

    Returns:
        The last element of the model's output, cropped to the original
        spatial size, with shape (1, C, H, W). C is presumably 3 for normals.

    NOTE(review): there is no ``torch.no_grad()`` here, so autograd tracks the
    forward pass. Wrap the call site if gradients are not needed.
    """
    device = img.device
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    img = img.unsqueeze(0)

    # Pad the input. lrtb = (left, right, top, bottom), which is the order
    # F.pad expects for the last two dimensions.
    _, _, orig_H, orig_W = img.shape
    lrtb = utils.get_padding(orig_H, orig_W)
    img = F.pad(img, lrtb, mode="constant", value=0.0)
    img = normalize(img)

    if intrins is None:
        intrins = intrins_from_fov(new_fov=60.0, H=orig_H, W=orig_W, device=device)

    # Work on a private copy on the image's device. unsqueeze() alone returns a
    # view, so the in-place shift below would otherwise mutate the caller's
    # tensor and accumulate the padding offset across calls.
    intrins = intrins.to(device).unsqueeze(0).clone()
    # Padding moves the image origin, so shift the principal point by the left/top pad.
    pad_left, pad_top = lrtb[0], lrtb[2]
    intrins[:, 0, 2] += pad_left
    intrins[:, 1, 2] += pad_top

    pred_norm = model(img, intrins=intrins)[-1]
    # Crop away the padding.
    pred_norm = pred_norm[:, :, pad_top:pad_top + orig_H, pad_left:pad_left + orig_W]

    return pred_norm


